1 module hip.hiprenderer.backend.metal.mtlshader;
2 
3 version(AppleOS):
4 
5 import hip.hiprenderer;
6 import hip.console.log;
7 import metal;
8 import metal.metal;
9 import hip.hiprenderer.backend.metal.mtlrenderer;
10 import hip.hiprenderer.backend.metal.mtltexture;
11 import core.int128;
12 
13 
/**
 * Metal fragment stage.
 *
 * Every batch getter yields an empty source. linkProgram concatenates the vertex and
 * fragment sources and looks up `fragment_main` in the result, so the fragment entry
 * point is presumably supplied by the vertex-side .metal file (confirm against the shaders).
 */
class HipMTLFragmentShader : FragmentShader
{
    ///Raw Metal source recorded by HipMTLShader.compileShader.
    string shaderSource;

    ///The device argument is currently unused.
    this(MTLDevice device){}

    ///Not implemented yet: yields a null string.
    override string getDefaultFragment(){ return null; } // TODO: implement

    override string getFrameBufferFragment(){ return ""; }
    override string getGeometryBatchFragment(){ return ""; }
    override string getSpriteBatchFragment(){ return ""; }
    override string getBitmapTextFragment(){ return ""; }
}
28 
/**
 * A growable pool of MTLBuffers backing one ShaderVariablesLayout.
 *
 * `currentBuffer` is a cursor selecting which buffer getBuffer hands out;
 * reset() moves it back to the first buffer.
 * Field order matters: instances are built positionally as (buffers, layout).
 */
struct BufferedMTLBuffer
{
    MTLBuffer[] buffer;
    ShaderVariablesLayout layout;
    uint currentBuffer;

    ///Buffer at the current cursor position.
    MTLBuffer getBuffer()
    {
        return buffer[currentBuffer];
    }

    ///Moves the cursor back to the first buffer.
    void reset()
    {
        currentBuffer = currentBuffer.init;
    }
}
38 
/**
 * Metal vertex stage.
 *
 * Each batch getter embeds a .metal file from `shaders/metal/` at compile time
 * through a string import.
 */
class HipMTLVertexShader : VertexShader
{
    ///Raw Metal source recorded by HipMTLShader.compileShader.
    string shaderSource;

    ///The device argument is currently unused.
    this(MTLDevice device){}

    ///Directory, relative to the string-import paths, containing the Metal sources.
    private enum metalShaderDir = "shaders/metal/";

    ///Not implemented yet: yields a null string.
    override string getDefaultVertex(){ return null; } // TODO: implement

    override string getFrameBufferVertex()  { return import(metalShaderDir ~ "framebuffer.metal"); }
    override string getGeometryBatchVertex(){ return import(metalShaderDir ~ "geometrybatch.metal"); }
    override string getSpriteBatchVertex()  { return import(metalShaderDir ~ "spritebatch.metal"); }
    override string getBitmapTextVertex()   { return import(metalShaderDir ~ "bitmaptext.metal"); }
}
67 
/**
 * Converts a HipBlendEquation into the matching Metal blend operation.
 *
 * DISABLED has no Metal counterpart and is reported as Add; setBlending turns
 * blending off separately through the attachment's `blendingEnabled` flag.
 */
MTLBlendOperation fromHipBlendEquation(HipBlendEquation eq)
{
    with(HipBlendEquation) final switch(eq)
    {
        case ADD, DISABLED:    return MTLBlendOperation.Add;
        case SUBTRACT:         return MTLBlendOperation.Subtract;
        case REVERSE_SUBTRACT: return MTLBlendOperation.ReverseSubtract;
        case MIN:              return MTLBlendOperation.Min;
        case MAX:              return MTLBlendOperation.Max;
    }
}
80 
/**
 * Converts a HipBlendFunction into the matching Metal blend factor.
 *
 * The CONSTANT_* functions map to Metal's constant blend color factors
 * (BlendColor / BlendAlpha and their one-minus variants), whose value comes from
 * the render command encoder's blend color. They must NOT map to the Source1*
 * factors: those are dual-source blending factors that read a second fragment
 * shader color output, which these shaders do not provide.
 */
MTLBlendFactor fromHipBlendFunction(HipBlendFunction fn)
{
    final switch(fn) with(HipBlendFunction)
    {
        case ZERO: return MTLBlendFactor.Zero;
        case ONE: return MTLBlendFactor.One;
        case SRC_COLOR: return MTLBlendFactor.SourceColor;
        case ONE_MINUS_SRC_COLOR: return MTLBlendFactor.OneMinusSourceColor;
        case SRC_ALPHA: return MTLBlendFactor.SourceAlpha;
        case ONE_MINUS_SRC_ALPHA: return MTLBlendFactor.OneMinusSourceAlpha;
        case DST_COLOR: return MTLBlendFactor.DestinationColor;
        case ONE_MINUS_DST_COLOR: return MTLBlendFactor.OneMinusDestinationColor;
        case DST_ALPHA: return MTLBlendFactor.DestinationAlpha;
        case ONE_MINUS_DST_ALPHA: return MTLBlendFactor.OneMinusDestinationAlpha;
        case CONSTANT_COLOR: return MTLBlendFactor.BlendColor;
        case ONE_MINUS_CONSTANT_COLOR: return MTLBlendFactor.OneMinusBlendColor;
        case CONSTANT_ALPHA: return MTLBlendFactor.BlendAlpha;
        case ONE_MINUS_CONSTANT_ALPHA: return MTLBlendFactor.OneMinusBlendAlpha;
    }
}
101 
/**
 * Metal state for one linked shader program.
 *
 * Holds the compiled library and its entry points, the uniform buffer pools
 * for the vertex and fragment stages, the render pipeline descriptor and the
 * pipeline state built from it, plus the last requested blend configuration.
 */
class HipMTLShaderProgram : ShaderProgram
{
    MTLLibrary library;
    MTLFunction vertexShaderFunction;
    MTLFunction fragmentShaderFunction;
    ///Uniform buffer pools per stage; either may be null when that stage has no variables block.
    BufferedMTLBuffer* uniformBufferVertex;
    BufferedMTLBuffer* uniformBufferFragment;

    MTLRenderPipelineDescriptor pipelineDescriptor;
    ///Null until createInputLayout succeeds.
    MTLRenderPipelineState pipelineState;
    ///Last blending configuration requested for this program.
    HipBlendFunction blendSrc, blendDst;
    HipBlendEquation blendEq;

    this()
    {
        pipelineDescriptor = MTLRenderPipelineDescriptor.alloc.initialize;
    }

    /**
     * (Re)builds the render pipeline state from the current descriptor using `descriptor`
     * as the vertex layout. Any previous pipeline state is released first; if creation
     * fails, an error message is shown and pipelineState stays null.
     */
    void createInputLayout(MTLDevice device, MTLVertexDescriptor descriptor)
    {
        if(pipelineState !is null)
        {
            pipelineState.release();
            //Never keep a dangling handle to the released state.
            pipelineState = null;
        }
        NSError err;
        pipelineDescriptor.vertexDescriptor = descriptor;
        pipelineState = device.newRenderPipelineStateWithDescriptor(pipelineDescriptor, &err);
        if(err !is null || pipelineState is null)
        {
            import hip.error.handler;
            ErrorHandler.showErrorMessage("Creating Input Layout",  "Could not create RenderPipelineState");
            //The pipeline state may be null without an NSError being filled in.
            if(err !is null)
                err.print();
        }
    }
}
138 
139 
140 
///Program most recently bound through HipMTLShader.bind; null when none (process-wide, not thread-local).
__gshared HipMTLShaderProgram boundShader = null;
142 
/**
 * Metal implementation of IShader.
 *
 * compileShader only records source text. The vertex and fragment sources are
 * concatenated and compiled together into one MTLLibrary in linkProgram. Uniform
 * data lives in per-layout BufferedMTLBuffer pools bound at buffer slot 0 of each stage.
 */
class HipMTLShader : IShader
{
    MTLDevice device;
    HipMTLRenderer mtlRenderer;

    this(MTLDevice device, HipMTLRenderer mtlRenderer)
    {
        this.device = device;
        this.mtlRenderer = mtlRenderer;
    }

    VertexShader createVertexShader(){return new HipMTLVertexShader(device);}
    FragmentShader createFragmentShader(){return new HipMTLFragmentShader(device);}
    ShaderProgram createShaderProgram(){return new HipMTLShaderProgram();}

    ///Records the fragment source for linkProgram; nothing is compiled here, so this always succeeds.
    bool compileShader(FragmentShader fs, string shaderSource)
    {
        (cast(HipMTLFragmentShader)fs).shaderSource = shaderSource;
        return true;
    }

    ///Records the vertex source for linkProgram; nothing is compiled here, so this always succeeds.
    bool compileShader(VertexShader vs, string shaderSource)
    {
        (cast(HipMTLVertexShader)vs).shaderSource = shaderSource;
        return true;
    }

    /**
     * Compiles the concatenated vertex + fragment source into one MTLLibrary and resolves the
     * `vertex_main` / `fragment_main` entry points. It then fills the program's pipeline descriptor
     * (BGRA8 sRGB color, Depth32Float_Stencil8 depth/stencil). The pipeline state itself is only
     * created later by HipMTLShaderProgram.createInputLayout.
     *
     * Returns: false when compilation fails or an entry point is missing.
     */
    bool linkProgram(ref ShaderProgram program, VertexShader vs, FragmentShader fs)
    {
        HipMTLShaderProgram p = cast(HipMTLShaderProgram)program;
        HipMTLVertexShader v = cast(HipMTLVertexShader)vs;
        HipMTLFragmentShader f = cast(HipMTLFragmentShader)fs;

        //`~` always yields a fresh GC copy, so this temporary can be freed eagerly.
        string shaderSource = v.shaderSource~f.shaderSource;
        scope(exit)
        {
            import core.memory;
            GC.free(cast(void*)shaderSource.ptr);
        }

        NSError err;
        MTLCompileOptions opts = MTLCompileOptions.alloc.initialize;
        //Obtained through alloc, so it is owned here; release it once compilation is over.
        scope(exit) opts.release();
        //Preprocessor macros
        opts.preprocessorMacros = cast(NSDictionary)(["ARGS_TIER2": 0].ns);

        p.library = device.newLibraryWithSource(shaderSource.ns, opts, &err);

        //Only a null library is a failure: Metal can also return an NSError carrying
        //compiler warnings alongside a successfully built library.
        if(p.library is null)
        {
            loglnError("Could not compile shader.");
            if(err !is null)
                err.print();
            return false;
        }
        if(err !is null)
            err.print();

        p.fragmentShaderFunction = p.library.newFunctionWithName("fragment_main".ns);
        if(p.fragmentShaderFunction is null)
        {
            loglnError("fragment_main() not found.");
            return false;
        }
        p.vertexShaderFunction = p.library.newFunctionWithName("vertex_main".ns);
        if(p.vertexShaderFunction is null)
        {
            loglnError("vertex_main() not found.");
            return false;
        }

        p.pipelineDescriptor.label = "HipremeShader".ns;
        p.pipelineDescriptor.vertexFunction = p.vertexShaderFunction;
        p.pipelineDescriptor.fragmentFunction = p.fragmentShaderFunction;
        p.pipelineDescriptor.colorAttachments[0].pixelFormat = MTLPixelFormat.BGRA8Unorm_sRGB;
        p.pipelineDescriptor.depthAttachmentPixelFormat = MTLPixelFormat.Depth32Float_Stencil8;
        p.pipelineDescriptor.stencilAttachmentPixelFormat = MTLPixelFormat.Depth32Float_Stencil8;

        return true;
    }

    /**
     * Handles only `texture_array` variables. The Metal textures and samplers of the given
     * IHipTexture[] (passed through `value`) are packed into a HipMTLShaderVarTexture and
     * stored blackboxed in `sv`.
     * Returns: true when the variable type was handled, false otherwise.
     */
    bool setShaderVar(ShaderVar* sv, ShaderProgram prog, void* value)
    {
        switch(sv.type) with(UniformType)
        {
            case texture_array:
            {
                import hip.util.algorithm;
                IHipTexture[] textures = *cast(IHipTexture[]*)value;
                HipMTLShaderVarTexture tempTex;
                foreach(size_t i, HipMTLTexture tex; textures.map((IHipTexture itex) => cast(HipMTLTexture)itex.getBackendHandle()))
                {
                    tempTex.textures[i] = tex.texture;
                    tempTex.samplers[i] = tex.sampler;
                }
                sv.setBlackboxed(tempTex);
                return true;
            }
            default: return false;
        }
    }

    /**
     * Records the blend configuration on the program and writes it to color attachment 0,
     * using the same factors and operation for RGB and alpha. If a pipeline state already
     * exists, it is rebuilt so the new blending takes effect.
     */
    void setBlending(ShaderProgram prog, HipBlendFunction src, HipBlendFunction dest, HipBlendEquation eq)
    {
        HipMTLShaderProgram p = cast(HipMTLShaderProgram)prog;
        p.blendSrc = src;
        p.blendDst = dest;
        p.blendEq = eq;

        MTLBlendFactor mtlSrc = src.fromHipBlendFunction;
        MTLBlendFactor mtlDest = dest.fromHipBlendFunction;
        MTLBlendOperation mtlOp = eq.fromHipBlendEquation;
        p.pipelineDescriptor.colorAttachments[0].blendingEnabled = eq != HipBlendEquation.DISABLED;
        p.pipelineDescriptor.colorAttachments[0].rgbBlendOperation = mtlOp;
        p.pipelineDescriptor.colorAttachments[0].alphaBlendOperation = mtlOp;
        p.pipelineDescriptor.colorAttachments[0].sourceRGBBlendFactor = mtlSrc;
        p.pipelineDescriptor.colorAttachments[0].destinationRGBBlendFactor = mtlDest;
        p.pipelineDescriptor.colorAttachments[0].sourceAlphaBlendFactor = mtlSrc;
        p.pipelineDescriptor.colorAttachments[0].destinationAlphaBlendFactor = mtlDest;

        if(p.pipelineState !is null)
            p.createInputLayout(device, p.pipelineDescriptor.vertexDescriptor);
    }

    /**
     * Sets the program's pipeline state on the current encoder, plus its current uniform
     * buffers (offset 0, slot 0) for each stage that has one. Does nothing while the program
     * has no pipeline state.
     */
    void bind(ShaderProgram program)
    {
        HipMTLShaderProgram mtlShader = cast(HipMTLShaderProgram)program;
        if(mtlShader.pipelineState !is null)
        {
            mtlRenderer.getEncoder.setRenderPipelineState(mtlShader.pipelineState);
            if(mtlShader.uniformBufferVertex)
                mtlRenderer.getEncoder.setVertexBuffer(mtlShader.uniformBufferVertex.getBuffer, 0, 0);
            if(mtlShader.uniformBufferFragment)
                mtlRenderer.getEncoder.setFragmentBuffer(mtlShader.uniformBufferFragment.getBuffer, 0, 0);
            boundShader = mtlShader;
        }
    }

    ///Clears buffer slot 0 of both stages; the pipeline state is left in place.
    void unbind(ShaderProgram program)
    {
        // encoder.setRenderPipelineState(null);
        mtlRenderer.getEncoder.setVertexBuffer(null, 0, 0);
        mtlRenderer.getEncoder.setFragmentBuffer(null, 0, 0);
        if(boundShader is program) boundShader = null;
    }

    ///No-op: vertex layouts are supplied through HipMTLShaderProgram.createInputLayout instead.
    void sendVertexAttribute(uint layoutIndex, int valueAmount, uint dataType, bool normalize, uint stride, int offset)
    {

    }

    ///Not implemented; always returns 0.
    int getId(ref ShaderProgram prog, string name)
    {
        return int.init; // TODO: implement
    }

    void deleteShader(FragmentShader* fs){}
    void deleteShader(VertexShader* vs){}

    ///Allocates an MTLBuffer large enough for `layout`'s data block, with default cache mode.
    private MTLBuffer getNewMTLBuffer(ShaderVariablesLayout layout)
    {
        return device.newBuffer(layout.getLayoutSize(), MTLResourceOptions.DefaultCache);
    }

    /**
     * Creates the first uniform buffer for `layout` and wraps it in a BufferedMTLBuffer, which
     * is stored as the layout's additional data. The pool is then attached to the owning
     * program's vertex or fragment slot. GEOMETRY / NONE layouts get a pool but are not attached.
     */
    void createVariablesBlock(ref ShaderVariablesLayout layout)
    {
        MTLBuffer buffer = getNewMTLBuffer(layout);
        HipMTLShaderProgram s = cast(HipMTLShaderProgram)(layout.getShader()).shaderProgram;
        BufferedMTLBuffer* buffered;
        layout.setAdditionalData(buffered = new BufferedMTLBuffer([buffer], layout), true);
        final switch(layout.shaderType)
        {
            case ShaderTypes.VERTEX:
                s.uniformBufferVertex = buffered;
                break;
            case ShaderTypes.FRAGMENT:
                s.uniformBufferFragment = buffered;
                break;
            case ShaderTypes.GEOMETRY:
            case ShaderTypes.NONE:
                break;
        }
    }

    /**
     * Copies each layout's block data into the current buffer of its pool, then advances the
     * cursor. A new buffer is appended whenever the pool is exhausted, presumably so data already
     * encoded for earlier draws this frame is not overwritten. onRenderFrameEnd rewinds the cursors.
     */
    void sendVars(ref ShaderProgram prog, ShaderVariablesLayout[string] layouts)
    {
        import core.stdc.string;

        foreach(layout; layouts)
        {
            BufferedMTLBuffer* bufferedUniformBuffer = cast(BufferedMTLBuffer*)layout.getAdditionalData();
            MTLBuffer uniformBuffer = bufferedUniformBuffer.getBuffer;
            memcpy(uniformBuffer.contents, layout.getBlockData, layout.getLayoutSize);
            bufferedUniformBuffer.currentBuffer++;
            if(bufferedUniformBuffer.currentBuffer >= bufferedUniformBuffer.buffer.length)
                bufferedUniformBuffer.buffer~= getNewMTLBuffer(layout);
        }
    }

    /**
     * Binds `textures` and their samplers to fragment slots [0, textures.length).
     * The handles are gathered in function-static, unsynchronized scratch arrays that are
     * reallocated (via hip.util.memory) only when more textures than before are requested.
     */
    void bindArrayOfTextures(ref ShaderProgram prog, IHipTexture[] textures, string varName)
    {
        __gshared MTLTexture[] mtlTextures;
        __gshared MTLSamplerState[] mtlSamplers;
        if(textures.length > mtlTextures.length)
        {
            import hip.util.memory;
            if(mtlTextures !is null)
            {
                free(mtlTextures.ptr);
                free(mtlSamplers.ptr);
            }
            mtlTextures = allocSlice!MTLTexture(textures.length);
            mtlSamplers = allocSlice!MTLSamplerState(textures.length);
        }

        foreach(i; 0..textures.length)
        {
            HipMTLTexture hMtl = cast(HipMTLTexture)textures[i].getBackendHandle();
            mtlTextures[i] = hMtl.texture;
            mtlSamplers[i] = hMtl.sampler;
        }

        mtlRenderer.getEncoder.setFragmentSamplerStates(mtlSamplers.ptr, NSRange(0, textures.length));
        mtlRenderer.getEncoder.setFragmentTextures(mtlTextures.ptr, NSRange(0, textures.length));

    }

    /**
     * Releases every MTLBuffer in the program's uniform pools.
     * A stage without a variables block has a null pool, which is skipped.
     */
    void dispose(ref ShaderProgram p)
    {
        HipMTLShaderProgram shader = cast(HipMTLShaderProgram)p;
        //Static array: no GC allocation just to iterate two pointers.
        BufferedMTLBuffer*[2] pools = [shader.uniformBufferFragment, shader.uniformBufferVertex];
        foreach(BufferedMTLBuffer* buff; pools)
        {
            if(buff is null)
                continue;
            foreach(MTLBuffer mtlbuffer; buff.buffer)
                mtlbuffer.release();
        }
    }

    ///Rewinds the uniform pools so the next frame reuses buffers from the start; null pools are skipped.
    override void onRenderFrameEnd(ShaderProgram p)
    {
        HipMTLShaderProgram shader = cast(HipMTLShaderProgram)p;
        if(shader.uniformBufferFragment !is null)
            shader.uniformBufferFragment.reset;
        if(shader.uniformBufferVertex !is null)
            shader.uniformBufferVertex.reset;
    }
}